package com.hy.springboot.common.config;

import org.slf4j.MDC;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;

/**
 * {@link ThreadPoolTaskExecutor} that propagates the submitting thread's SLF4J {@link MDC}
 * context to the thread that runs each task.
 * <p>
 * The MDC map is captured when a task is handed to {@link #execute(Runnable)},
 * {@link #submit(Runnable)} or {@link #submit(Callable)}. It is installed on the thread that
 * runs the task, and that thread's previous MDC is restored once the task finishes.
 * Restoring (rather than unconditionally clearing) matters when a task runs on the submitting
 * thread itself, e.g. under a caller-runs rejection policy: the caller keeps its own logging
 * context. On a pool thread the previous MDC is normally empty, so the MDC ends up cleared
 * and no stale entries leak into the next task.
 * <p>
 * NOTE(review): only the entry points overridden here are wrapped. Any other submission
 * methods inherited from {@code ThreadPoolTaskExecutor} (e.g. {@code submitListenable} in
 * some Spring versions) bypass MDC propagation. Verify this against the Spring version in
 * use, or consider {@code setTaskDecorator} instead.
 *
 * @author lccsetsun
 * @since 2024/1/31 11:08
 **/
public class CustomThreadPoolTaskExecutor extends ThreadPoolTaskExecutor {

	/**
	 * Executes {@code task} with the caller's MDC context.
	 *
	 * @param task the task to run; must not be {@code null}
	 * @throws NullPointerException if {@code task} is {@code null}
	 */
	@Override
	public void execute(Runnable task) {
		super.execute(wrap(task));
	}

	/**
	 * Submits {@code task} for execution with the caller's MDC context.
	 *
	 * @param task the task to run; must not be {@code null}
	 * @return a future representing pending completion of the task
	 * @throws NullPointerException if {@code task} is {@code null}
	 */
	@Override
	public Future<?> submit(Runnable task) {
		return super.submit(wrap(task));
	}

	/**
	 * Submits a value-returning {@code task} for execution with the caller's MDC context.
	 *
	 * @param task the task to run; must not be {@code null}
	 * @param <T>  the task's result type
	 * @return a future that yields the task's result
	 * @throws NullPointerException if {@code task} is {@code null}
	 */
	@Override
	public <T> Future<T> submit(Callable<T> task) {
		return super.submit(wrap(task));
	}

	/**
	 * Wraps {@code callable} so that it runs with the MDC context of the thread calling this
	 * method, restoring the executing thread's own MDC afterwards.
	 *
	 * @throws NullPointerException if {@code callable} is {@code null} (fail fast on the
	 *                              submitting thread instead of later on the worker)
	 */
	private static <T> Callable<T> wrap(final Callable<T> callable) {
		Objects.requireNonNull(callable, "task");
		// Snapshot the submitting thread's MDC now; the returned lambda runs later, possibly elsewhere.
		final Map<String, String> captured = MDC.getCopyOfContextMap();
		return () -> {
			final Map<String, String> previous = MDC.getCopyOfContextMap();
			applyContext(captured);
			try {
				return callable.call();
			} finally {
				// Put back whatever the executing thread had; on a pool thread this normally
				// clears the MDC, so no context leaks into the next task.
				applyContext(previous);
			}
		};
	}

	/**
	 * Wraps {@code runnable} so that it runs with the MDC context of the thread calling this
	 * method, restoring the executing thread's own MDC afterwards.
	 *
	 * @throws NullPointerException if {@code runnable} is {@code null} (fail fast on the
	 *                              submitting thread instead of later on the worker)
	 */
	private static Runnable wrap(final Runnable runnable) {
		Objects.requireNonNull(runnable, "task");
		// Snapshot the submitting thread's MDC now; the returned lambda runs later, possibly elsewhere.
		final Map<String, String> captured = MDC.getCopyOfContextMap();
		return () -> {
			final Map<String, String> previous = MDC.getCopyOfContextMap();
			applyContext(captured);
			try {
				runnable.run();
			} finally {
				// Put back whatever the executing thread had; on a pool thread this normally
				// clears the MDC, so no context leaks into the next task.
				applyContext(previous);
			}
		};
	}

	/**
	 * Replaces the current thread's MDC with {@code context}. A {@code null} or empty map
	 * clears the MDC, so a task never inherits stale entries from an earlier task.
	 */
	private static void applyContext(final Map<String, String> context) {
		if (context == null || context.isEmpty()) {
			MDC.clear();
		} else {
			MDC.setContextMap(context);
		}
	}
}
